"""Module of base classes and helper methods for imitation learning algorithms."""

import abc
from typing import Any, Generic, Iterable, Mapping, Optional, TypeVar, Union

import numpy as np
import torch as th
import torch.utils.data as th_data
from stable_baselines3.common import policies

from imitation.data import rollout, types
from imitation.util import logger as imit_logger


class BaseImitationAlgorithm(abc.ABC):
    """Base class for all imitation learning algorithms."""

    _logger: imit_logger.HierarchicalLogger
    """Object to log statistics and natural language messages to."""

    allow_variable_horizon: bool
    """If True, allow variable horizon trajectories; otherwise error if detected."""

    _horizon: Optional[int]
    """Horizon of trajectories seen so far (None if no trajectories seen)."""

    def __init__(
        self,
        *,
        custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
        allow_variable_horizon: bool = False,
    ):
        """Creates an imitation learning algorithm.

        Args:
            custom_logger: Where to log to; if None (default), creates a new logger.
            allow_variable_horizon: If False (default), algorithm will raise an
                exception if it detects trajectories of different length during
                training. If True, overrides this safety check. WARNING: variable
                horizon episodes leak information about the reward via termination
                condition, and can seriously confound evaluation. Read
                https://imitation.readthedocs.io/en/latest/guide/variable_horizon.html
                before overriding this.
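
        Example:
            A minimal construction sketch; `BaseImitationAlgorithm` declares
            no abstract methods, so it can be instantiated directly for
            illustration::

                algo = BaseImitationAlgorithm(allow_variable_horizon=True)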
        """
        self._logger = custom_logger or imit_logger.configure()
        self.allow_variable_horizon = allow_variable_horizon
        if allow_variable_horizon:
            self.logger.warn(
                "Running with `allow_variable_horizon` set to True. "
                "Some algorithms are biased towards shorter or longer "
                "episodes, which may significantly confound results. "
                "Additionally, even unbiased algorithms can exploit "
                "the information leak from the termination condition, "
                "producing spuriously high performance. See "
                "https://imitation.readthedocs.io/en/latest/guide/variable_horizon.html"
                " for more information.",
            )
        self._horizon = None

    @property
    def logger(self) -> imit_logger.HierarchicalLogger:
        """Returns the logger for this algorithm."""
        return self._logger

    @logger.setter
    def logger(self, value: imit_logger.HierarchicalLogger) -> None:
        self._logger = value

    def _check_fixed_horizon(self, horizons: Iterable[int]) -> None:
        """Checks that episode lengths in `horizons` are fixed and equal to prior calls.

        If algorithm is safe to use with variable horizon episodes (e.g. behavioral
        cloning), then just don't call this method.

        Args:
            horizons: An iterable sequence of episode lengths.

        Raises:
            ValueError: The episode lengths in `horizons` differ from one
                another, or from episode lengths in previous calls to this
                method.
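
        Example:
            A minimal sketch of the bookkeeping, where `algo` is any
            `BaseImitationAlgorithm` instance with the default
            `allow_variable_horizon=False`::

                algo._check_fixed_horizon([5, 5])  # OK: records horizon 5
                algo._check_fixed_horizon([5])     # OK: matches recorded horizon
                algo._check_fixed_horizon([7])     # raises ValueError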
        """
        if self.allow_variable_horizon:  # user has opted out of this check
            return

        # collect all horizons seen so far (including the new ones)
        horizons = set(horizons)
        if self._horizon is not None:
            horizons.add(self._horizon)

        if len(horizons) > 1:
            raise ValueError(
                f"Episodes of different length detected: {horizons}. "
                "Variable horizon environments are discouraged -- "
                "termination conditions leak information about reward. See"
                "https://imitation.readthedocs.io/en/latest/guide/variable_horizon.html"
                " for more information. If you are SURE you want to run imitation on a "
                "variable horizon task, then please pass in the flag: "
                "`allow_variable_horizon=True`.",
            )
        elif len(horizons) == 1:
            self._horizon = horizons.pop()

    def __getstate__(self):
        """Returns pickleable state, excluding the logger.

        The logger cannot be pickled as it may depend on open file handles.
        """
        state = self.__dict__.copy()
        del state["_logger"]
        return state

    def __setstate__(self, state):
        """Restores pickled state, configuring a fresh logger if none was saved.

        Callers wanting a custom logger should set `self.logger` directly
        after unpickling.
        """
        self.__dict__.update(state)
        self.logger = state.get("_logger") or imit_logger.configure()
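
    # Pickling sketch: `pickle.loads(pickle.dumps(algo))` drops the logger on
    # dump and configures a fresh one on load, so the restored algorithm is
    # immediately usable.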


TransitionMapping = Mapping[str, Union[np.ndarray, th.Tensor]]
TransitionKind = TypeVar("TransitionKind", bound=types.TransitionsMinimal)
AnyTransitions = Union[
    Iterable[types.Trajectory],
    Iterable[TransitionMapping],
    TransitionKind,
]
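# For illustration, a single `TransitionMapping` batch is a mapping like
# `{"obs": <array of length B>, "acts": <array of length B>, ...}`; the
# arrays may be either NumPy arrays or Torch tensors.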


class DemonstrationAlgorithm(BaseImitationAlgorithm, Generic[TransitionKind]):
    """An algorithm that learns from demonstration: BC, IRL, etc."""

    def __init__(
        self,
        *,
        demonstrations: Optional[AnyTransitions],
        custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
        allow_variable_horizon: bool = False,
    ):
        """Creates an algorithm that learns from demonstrations.

        Args:
            demonstrations: Demonstrations from an expert (optional). Can be
                expressed directly as a `types.TransitionsMinimal` object, a
                sequence of trajectories, or an iterable of transition batches
                (mappings from keywords to arrays containing observations,
                etc).
            custom_logger: Where to log to; if None (default), creates a new logger.
            allow_variable_horizon: If False (default), algorithm will raise an
                exception if it detects trajectories of different length during
                training. If True, overrides this safety check. WARNING: variable
                horizon episodes leak information about the reward via termination
                condition, and can seriously confound evaluation. Read
                https://imitation.readthedocs.io/en/latest/guide/variable_horizon.html
                before overriding this.
        """
        super().__init__(
            custom_logger=custom_logger,
            allow_variable_horizon=allow_variable_horizon,
        )

        if demonstrations is not None:
            self.set_demonstrations(demonstrations)

    @abc.abstractmethod
    def set_demonstrations(self, demonstrations: AnyTransitions) -> None:
        """Sets the demonstration data.

        Changing the demonstration data on-demand can be useful for
        interactive algorithms like DAgger.

        Args:
            demonstrations: Either a Torch `DataLoader`, any other iterator
                that yields dictionaries containing "obs" and "acts" Tensors
                or NumPy arrays, a `TransitionKind` instance, or a sequence of
                `Trajectory` objects.
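
        Example:
            A hedged sketch, assuming `algo` is an instance of a concrete
            subclass and `trajs` is a list of expert `types.Trajectory`
            objects::

                algo.set_demonstrations(trajs)  # swap in new demonstrations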
        """

    @property
    @abc.abstractmethod
    def policy(self) -> policies.BasePolicy:
        """Returns a policy imitating the demonstration data."""


class _WrappedDataLoader:
    """Wraps a data loader (batch iterable) and checks for specified batch size."""

    def __init__(
        self,
        data_loader: Iterable[TransitionMapping],
        expected_batch_size: int,
    ):
        """Builds _WrapedDataLoader.

        Args:
            data_loader: The data loader (batch iterable) to wrap.
            expected_batch_size: The batch size to check for.
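
        Example:
            A sketch using a hypothetical in-memory iterable of 2-element
            batches::

                batches = [{"obs": np.zeros(2), "acts": np.zeros(2)}] * 3
                wrapped = _WrappedDataLoader(batches, expected_batch_size=2)
                list(wrapped)  # iterates cleanly; any size mismatch would
                               # raise ValueError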
        """
        self.data_loader = data_loader
        self.expected_batch_size = expected_batch_size

    def __iter__(self):
        """Iterator yielding data from `self.data_loader`, checking `self.expected_batch_size`.

        Yields:
            Identity -- yields same batches as from `self.data_loader`.

        Raises:
            ValueError: `self.data_loader` returns a batch of size not equal to
                `self.expected_batch_size`.
        """
        for batch in self.data_loader:
            # Check each required key against the expected batch size.
            for key in ("obs", "acts"):
                if len(batch[key]) != self.expected_batch_size:
                    raise ValueError(
                        f"Expected batch size {self.expected_batch_size} "
                        f"!= {len(batch[key])} = len(batch['{key}'])",
                    )
            yield batch


def make_data_loader(
    transitions: AnyTransitions,
    batch_size: int,
    data_loader_kwargs: Optional[Mapping[str, Any]] = None,
) -> Iterable[TransitionMapping]:
    """Converts demonstration data to Torch data loader.

    Args:
        transitions: Transitions expressed directly as a `types.TransitionsMinimal`
            object, a sequence of trajectories, or an iterable of transition
            batches (mappings from keywords to arrays containing observations, etc).
        batch_size: The size of the batch to create. Does not change the batch size
            if `transitions` is already an iterable of transition batches.
        data_loader_kwargs: Arguments to pass to `th_data.DataLoader`.

    Returns:
        An iterable of transition batches.

    Raises:
        ValueError: if `transitions` is an iterable over transition batches with batch
            size not equal to `batch_size`; or if `transitions` is transitions or a
            sequence of trajectories with total timesteps less than `batch_size`.
        TypeError: if `transitions` is an unsupported type.
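
    Example:
        A minimal usage sketch, assuming `trajs` is a sequence of expert
        `types.Trajectory` objects::

            loader = make_data_loader(trajs, batch_size=32)
            for batch in loader:
                obs, acts = batch["obs"], batch["acts"]  # each of length 32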
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size={batch_size} must be positive.")

    if isinstance(transitions, Iterable):
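        # Peek at the first element to detect a sequence of trajectories.
        # Note this assumes `transitions` is re-iterable (e.g. a sequence or
        # `DataLoader`); peeking consumes the first element of a one-shot
        # generator.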
        try:
            first_item = next(iter(transitions))
        except StopIteration:
            first_item = None
        if isinstance(first_item, types.Trajectory):
            transitions = rollout.flatten_trajectories(list(transitions))

    if isinstance(transitions, types.TransitionsMinimal):
        if len(transitions) < batch_size:
            raise ValueError(
                f"Number of transitions in `transitions` ({len(transitions)}) "
                f"is smaller than batch size ({batch_size}).",
            )

        extra_kwargs = dict(shuffle=True, drop_last=True)
        if data_loader_kwargs is not None:
            extra_kwargs.update(data_loader_kwargs)
        return th_data.DataLoader(
            transitions,
            batch_size=batch_size,
            collate_fn=types.transitions_collate_fn,
            **extra_kwargs,
        )
    elif isinstance(transitions, Iterable):
        return _WrappedDataLoader(transitions, batch_size)
    else:
        raise TypeError(f"`transitions` has unexpected type {type(transitions)}")
